'''
This implements the classic REINFORCE algorithm, which is the most basic and fundamental 
policy-gradient method. Policy-gradient methods are, of course, 
how modern LLMs are post-trained (where the LLM is the policy mapping states -> actions). 

I'm not going to derive the math here, but the basic idea is that we can measure how good our policy is 
by looking at how much reward each rollout (action sequence guided by our policy) generates. 
Then, we have to judge/infer how much each individual action in the rollout (s0, a0, s1, a1...)
helped. Doing so is often called solving the "credit assignment problem" and here since each step we stay alive/upright 
we get reward one, it's easy. 

Then, we update our policy to increase the likelihood of actions that led to positive reward, and vice versa for negative reward. 
This leads to the update rule Loss = -(return_action * logprob_action), where the negative sign is there because we want to 
maximize reward but an optimizer minimizes a loss. 

One subtlety is that we don't actually use the reward directly at each action, but "returns", which are the time-discounted 
sum of the rewards from that step onward, to account for the fact that reward now is more valuable than reward far in the future, 
since it's not guaranteed we will survive to the future to get that reward, and thus we should downweight it. 
IMO, modern policy-gradient methods are essentially REINFORCE++ and deeply understanding this algorithm is kind of enough 
to algorithmically get a sense of modern policy-gradient algorithms like PPO/GRPO. In practice in RL, 
modeling reward well (in a cheap to compute and unhackable way) is the hard part and details like which policy 
optimization algorithm you use, though important, matter much less than quality of reward modeling. 
'''


import gym 
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm 
import argparse

# Module-level CartPole-v1 environment, created once when this module is loaded.
env = gym.make('CartPole-v1')
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PolicyNet(nn.Module):  # states -> action logits
    """Two-layer MLP policy that maps an observation to unnormalized action logits.

    Args:
        nstates: size of the observation vector fed to the first linear layer.
        nactions: number of discrete actions, i.e. size of the output logits.
        hidden_dim: width of the single hidden layer.
        act: activation module applied after the first linear layer. When
            ``None`` (the default) a fresh ``nn.GELU()`` is created for this
            instance.
    """

    def __init__(self, nstates=4, nactions=2, hidden_dim=128, act=None):
        super().__init__()
        self.w1 = nn.Linear(nstates, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, nactions)
        # Build the default activation here instead of in the signature: a
        # default of ``nn.GELU()`` is evaluated once at definition time, so
        # every PolicyNet would register the very same module object.
        self.act = nn.GELU() if act is None else act

    def forward(self, x):
        """Return logits of shape ``[..., nactions]`` for ``x`` of shape ``[..., nstates]``."""
        return self.w2(self.act(self.w1(x)))

def loss_fn(rewards, logprobs, mask, gamma=0.99, max_rollout_len=50):
    """Compute the REINFORCE loss for a padded batch of rollouts.

    Args:
        rewards: ``[b, T]`` per-step rewards; padded steps are expected to be 0.
        logprobs: ``[b, T]`` log-probabilities of the actions that were taken.
        mask: ``[b, T]`` with 1 for real steps and 0 for padding.
        gamma: discount factor.
        max_rollout_len: kept for backward compatibility only; the horizon
            ``T`` is read from ``rewards.shape[-1]`` so the two can never
            disagree.

    Returns:
        Scalar tensor: the mean over valid steps of
        ``-normalized_return_t * logprob_t``.
    """
    horizon = rewards.shape[-1]
    idx = torch.arange(horizon, device=rewards.device)

    # Discount matrix G[k, t] = gamma^(k - t) for k >= t and 0 otherwise, so that
    # (rewards @ G)[b, t] = sum_{k >= t} gamma^(k - t) * rewards[b, k], i.e. the
    # discounted return from step t onward. Exponents are clamped at 0 because
    # the entries above the diagonal are discarded by tril anyway, and raising
    # gamma to a large negative power could overflow before being discarded.
    exponent = (idx[:, None] - idx[None, :]).clamp(min=0).to(rewards.dtype)
    G = torch.tril(gamma ** exponent)
    returns = rewards @ G  # [b, T] @ [T, T] -> [b, T]

    # Interesting note: if you don't normalize returns, the loss INCREASES even
    # though the policy is IMPROVING, because average episode length grows and
    # the per-step loss grows with episode length.
    n_valid = mask.sum().clamp(min=1)  # guard against an all-padding batch
    mean_returns = (mask * returns).sum() / n_valid
    # Re-apply the mask after centering so padded slots contribute exactly 0
    # to the variance instead of (-mean)^2 each.
    centered_returns = mask * (returns - mean_returns)
    std_returns = torch.sqrt((centered_returns ** 2).sum() / n_valid)
    # Per-step standardized return; this must stay a [b, T] tensor so each
    # action is weighted by its own return rather than by a shared scalar.
    normalized_returns = centered_returns / (std_returns + 1e-8)

    # Loss = -E[R_t * logprob_t] over valid steps.
    loss_t = -normalized_returns * logprobs
    return loss_t.sum() / n_valid

def train(nsteps=100, batch_size=64, max_rollout_len=50, lr=1e-3, gamma=0.99, verbose=False,
          hidden_dim=128, early_stop=False, solved_threshold=195.0):
    """Train a PolicyNet on the module-level ``env`` with REINFORCE.

    Each training step samples ``batch_size`` fresh rollouts from the current
    policy (on-policy), pads them to ``max_rollout_len``, and takes a single
    AdamW step on ``loss_fn``.

    Args:
        nsteps: number of optimizer steps (one batch of rollouts per step).
        batch_size: number of rollouts collected per step.
        max_rollout_len: rollouts are cut off after this many environment steps.
        lr: AdamW learning rate.
        gamma: discount factor forwarded to ``loss_fn``.
        verbose: print progress information every 5 steps.
        hidden_dim: hidden-layer width of the policy network.
        early_stop: stop as soon as the batch-average episode length exceeds
            ``solved_threshold``.
        solved_threshold: average episode length that counts as solved. It can
            only be exceeded when ``max_rollout_len`` is larger than it.

    Returns:
        The trained ``PolicyNet``.
    """
    policy = PolicyNet(hidden_dim=hidden_dim).to(device)
    opt = torch.optim.AdamW(policy.parameters(), lr=lr)
    b = batch_size

    if verbose:
        print(f"Starting training with {nsteps} steps, batch size {batch_size}, max rollout length {max_rollout_len}")
        print(f"Learning rate: {lr}, discount factor (gamma): {gamma}")
        print(f"Policy network: {policy}")

    for step in tqdm(range(nsteps)):
        # On-policy: the batch we step on has JUST been generated by the current
        # policy (in contrast to DQN, which learns from a replay buffer).
        batch_rewards, batch_logprobs = [], []  # both become [b, max_rollout_len]
        # Entry i is the true length of rollout i; compared against arange below to build the mask.
        true_lens = torch.zeros(b)

        if verbose and step % 5 == 0:
            print(f"\nGenerating batch {step} of rollouts...")

        for batch_idx in range(b):
            i = 0
            # Allocate the padded buffers on the policy's device so writing a
            # log-prob into them does not trigger a cross-device copy per step.
            rollout_rewards = torch.zeros(max_rollout_len, device=device)
            rollout_logprobs = torch.zeros(max_rollout_len, device=device)
            done = False
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float32, device=device)
            # Generate a single rollout.
            while not done and i < max_rollout_len:
                # log_softmax instead of log(softmax(.)): if a probability
                # underflows to 0, the backward pass of log yields 0/0 = NaN
                # even for the action that was not selected.
                logprobs = F.log_softmax(policy(state), dim=-1)
                # Sampling is not differentiable, so sample from detached probabilities.
                next_action = torch.multinomial(logprobs.detach().exp(), 1).item()

                next_state, r, terminated, truncated, _ = env.step(next_action)
                done = terminated or truncated
                rollout_rewards[i] = r
                rollout_logprobs[i] = logprobs[next_action]

                state = torch.tensor(next_state, dtype=torch.float32, device=device)
                i += 1

                if done or i == max_rollout_len:
                    true_lens[batch_idx] = i
                    if verbose and batch_idx % 10 == 0 and step % 5 == 0:
                        print(f"  Rollout {batch_idx}: length {i}, final reward {r}, terminated: {terminated}")

            batch_rewards.append(rollout_rewards)
            batch_logprobs.append(rollout_logprobs)

        # Stack the per-rollout buffers into [b, max_rollout_len] tensors.
        batch_rewards = torch.stack(batch_rewards)
        batch_logprobs = torch.stack(batch_logprobs)

        # mask[j, t] = 1 while t < true_lens[j], 0 on padding -> [b, max_rollout_len]
        mask = (torch.arange(max_rollout_len, device=device)[None]
                < true_lens[:, None].to(device)).float()

        if verbose and step % 5 == 0:
            print(f"Computing loss for batch {step}...")

        loss = loss_fn(batch_rewards, batch_logprobs, mask, gamma=gamma, max_rollout_len=max_rollout_len)
        loss.backward()
        opt.step()
        opt.zero_grad()

        avg_length = true_lens.mean().item()
        if step % 5 == 0 and verbose:
            min_length = true_lens.min().item()
            max_length = true_lens.max().item()
            print(f"[{step:4d}/{nsteps:4d}]  ||   Loss = {loss.item():.4f}  ||  Reward(Avg Ep Len) = {avg_length:.1f}  ||  Min/Max Len = {min_length:.1f}/{max_length:.1f}")

        # Rollout length is used as the reward proxy here (as in the log line above).
        if early_stop and avg_length > solved_threshold:
            if verbose:
                print(f"Early stop at step {step}: average episode length {avg_length:.1f} > {solved_threshold}")
            break

    return policy


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train REINFORCE on CartPole')
    parser.add_argument('--nsteps', type=int, default=500, help='Number of training steps')
    parser.add_argument('--batch-size', type=int, default=256, help='Batch size')
    parser.add_argument('--max-rollout-len', type=int, default=200, help='Maximum rollout length')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor')
    parser.add_argument('--hidden-dim', type=int, default=128, help='Hidden dimension size for policy network')
    parser.add_argument('--early-stop', action='store_true', help='Stop training once environment is solved (avg reward > 195)')
    parser.add_argument('--wandb', action='store_true', help='Use wandb logging')
    parser.add_argument('--verbose', action='store_true', help='Print training progress')

    args = parser.parse_args()

    if args.verbose:
        print("Starting REINFORCE training on CartPole-v1 environment")
        print(f"Arguments: {args}")

    if args.wandb:
        # The flag is accepted for compatibility, but no wandb integration
        # exists in this file; say so instead of silently ignoring it.
        print("Warning: --wandb was requested but wandb logging is not implemented; continuing without it.")

    train(
        nsteps=args.nsteps,
        batch_size=args.batch_size,
        max_rollout_len=args.max_rollout_len,
        lr=args.lr,
        gamma=args.gamma,
        verbose=args.verbose,
        hidden_dim=args.hidden_dim,
        early_stop=args.early_stop,
    )